import torch

# Merge the separately saved FeatNet / CoordNet / RelCoordNet checkpoints and
# their prior tensors into a single "COFW_state.pth" file.
# Every file is read from, and the merged file is written to, this directory.
path = "../checkpoint/"


def _load(filename):
    """Return whatever ``torch.load`` reads from ``path + filename``."""
    # NOTE(review): no map_location is passed, so torch.load's default device
    # handling applies to the stored tensors.
    return torch.load(path + filename)


def _as_row(vector):
    """View a 128-element tensor as one (1, 128) row so it can be concatenated."""
    return vector.view(1, 128)


# --- FeatNet: three sub-network entries stored in one checkpoint -------------
featnet_state = _load("FeatNet_2000epoch.pth")
featnet_leye, featnet_reye, featnet_others = (
    featnet_state[key]
    for key in ('featnet_leye', 'featnet_reye', 'featnet_others')
)

# --- FeatNet priors -----------------------------------------------------------
featnet_prior = _load("FeatNet_prior.pth")
leye_z_mean = featnet_prior['leye_z_mean']
reye_z_mean = featnet_prior['reye_z_mean']
others_z_mean = featnet_prior['others_z_mean']

# Eyes: the last row is placed first, followed by rows 0..7.
leye_z_feat = torch.cat([_as_row(leye_z_mean[-1]), leye_z_mean[:8]], dim=0)
reye_z_feat = torch.cat([_as_row(reye_z_mean[-1]), reye_z_mean[:8]], dim=0)
# Others: row 3 is placed first; the remaining rows keep their relative order.
others_z_feat = torch.cat(
    [_as_row(others_z_mean[3]), others_z_mean[:3], others_z_mean[4:]],
    dim=0,
)

# --- CoordNet -----------------------------------------------------------------
coordnet_state = _load("CoordNet_2000epoch.pth")
coordnet = coordnet_state['coordnet']

# --- CoordNet priors: rows are gathered in exactly the listed order -----------
coordnet_prior = _load("CoordNet_prior.pth")
landmark_z_mean = coordnet_prior['landmark_z_mean']
leye_idx = [16, 0, 2, 4, 5, 8, 10, 12, 13]
reye_idx = [17, 1, 3, 6, 7, 9, 11, 14, 15]
others_idx = [21, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28]
leye_z_coord = landmark_z_mean[leye_idx]
reye_z_coord = landmark_z_mean[reye_idx]
others_z_coord = landmark_z_mean[others_idx]

# --- RelCoordNet --------------------------------------------------------------
relcoordnet_state = _load("RelCoordNet_2000epoch.pth")
relcoordnet = relcoordnet_state['relcoordnet']

# --- RelCoordNet priors: only the "others" rows are kept ----------------------
relcoordnet_prior = _load("RelCoordNet_prior.pth")
others_z_relcoord = relcoordnet_prior['relative_landmark_z_mean'][others_idx]

# --- Assemble and write the merged checkpoint ---------------------------------
# Key suffixes: "_ft" = FeatNet prior, "_cd" = CoordNet prior,
# "_rcd" = RelCoordNet prior.
state = {}
state['featnet_leye'] = featnet_leye
state['featnet_reye'] = featnet_reye
state['featnet_others'] = featnet_others
state['coordnet'] = coordnet
state['relcoordnet'] = relcoordnet
state['leye_z_ft'] = leye_z_feat
state['reye_z_ft'] = reye_z_feat
state['others_z_ft'] = others_z_feat
state['leye_z_cd'] = leye_z_coord
state['reye_z_cd'] = reye_z_coord
state['others_z_cd'] = others_z_coord
state['others_z_rcd'] = others_z_relcoord

torch.save(state, path + "COFW_state.pth")

